import time, copy, os
import warnings
from typing import Any, Callable, Dict, List, Optional, Union
from array2gif import write_gif

import gymnasium as gym
import numpy as np
import torch
from tianshou.data import (
    Batch,
    CachedReplayBuffer,
    ReplayBuffer,
    ReplayBufferManager,
    VectorReplayBuffer,
    to_numpy,
)
import cv2
import imageio as imio
from tianshou.data import Collector
from tianshou.env import BaseVectorEnv

from Causal import Dynamics
from Policy.goal_policy import GoalPolicy
from State.buffer import VectorGCReplayBufferManager
from State.her_buffer import add_her_trajectory, her_summary_statistics
from State.utils import compute_proximity, ObjDict


def construct_her_obs(achieved_goal, desired_goal, env_obs, reached_graph_counter):
    """Bundle the goal-conditioned observation fields into one ``Batch``.

    :param achieved_goal: goal currently achieved, stored under ``achieved_goal``.
    :param desired_goal: goal being pursued, stored under ``desired_goal``.
    :param env_obs: raw environment observation. If it is itself a ``Batch``
        (goal-based env), only its ``observation`` field is kept, because the
        goals are supplied explicitly through the other two arguments.
    :param reached_graph_counter: stored unchanged under ``reached_graph_counter``.
    :return: ``Batch(achieved_goal, desired_goal, observation, reached_graph_counter)``.
    """
    # only the upper policy needs the goal; strip a goal-based env's wrapper batch
    raw_observation = env_obs.observation if isinstance(env_obs, Batch) else env_obs
    return Batch(
        achieved_goal=achieved_goal,
        desired_goal=desired_goal,
        observation=raw_observation,
        reached_graph_counter=reached_graph_counter,
    )


class GCCollector(Collector):
    """
    Modified collector for Goal Conditioned RL policies. In particular, this method \
    records additional information into the replay buffer. The goal-conditioned \
    changes to tianshou's stock collection loop are delimited in ``collect`` by the \
    "major modification starts / ends" and "GIF EDIT" comments.

    Collector enables the policy to interact with different types of envs with \
    exact number of steps or episodes.

    :param policy: a :class:`GoalPolicy`. Besides the usual tianshou policy interface
        this class calls ``get_target``, ``map_action``, ``check_rew_term_trunc``,
        ``rewtermdone``, ``random_sample``, ``update_stats``, ``update_history``,
        ``update_state_counts`` and reads ``num_factors``.
    :param dynamics: callable applied to the current transition batch; its output is
        stored as ``data.graph``.
    :param env: a ``gym.Env`` environment or an instance of the
        :class:`~tianshou.env.BaseVectorEnv` class.
    :param normalize_goal: applied to ``achieved_goal`` and ``desired_goal`` of every
        constructed observation. ``None`` means no normalization.
    :param normalize_obs: applied to the ``observation`` field of every constructed
        observation. ``None`` means no normalization.
    :param buffer: an instance of the :class:`~tianshou.data.ReplayBuffer` class.
        If set to None, it will not store the data. Default to None.
    :param timeout: length of the per-environment temporary HER buffers, and the
        divisor used when incrementing ``reached_graph_counter``.
    :param function preprocess_fn: a function called before the data has been added to
        the buffer, see issue #42 and :ref:`preprocess_fn`. Default to None.
    :param bool exploration_noise: determine whether the action needs to be modified
        with corresponding policy's exploration noise. If so, "policy.
        exploration_noise(act, batch)" will be called automatically to add the
        exploration noise into action. Default to False.
    :param name: stored as ``self.name``.
    :param root_dir: experiment directory; gifs are written to ``<root_dir>/gifs``.
    :param save_gif_num: number of environments to record gifs for while the policy
        is not training; a falsy value disables gif recording.
    :param num_factors: sizes ``graph_to_count_idx`` (``2**0 .. 2**num_factors``),
        used to encode graph rows as integers. NOTE(review): presumably this must
        equal ``policy.num_factors`` (which sizes ``achieve_graph``) — confirm.
    :param render: if True (and gifs are not being recorded), each step's frames are
        written as PNG files under ``collect(save_frame_dir=...)``.
    :param separate_her: if True and ``buffer.use_her``, transitions are also kept in
        per-environment temporary buffers that are handed to ``add_her_trajectory``
        when an episode ends.
    :param num_her_resamples: ``num_samples`` forwarded to ``add_her_trajectory``.
    :param her_traj_length: forwarded to ``add_her_trajectory``.
    :param use_lowest_post: forwarded to ``add_her_trajectory``.
    :param goal_conditioned: if True, per-environment resets call
        ``env.reset(ids, reset_kwargs_list, **kwargs)``, i.e. the vector env takes a
        list of reset kwargs as its second positional argument.

    The "preprocess_fn" is a function called before the data has been added to the
    buffer with batch format. It will receive only "obs" and "env_id" when the
    collector resets the environment, and will receive the keys "obs_next", "rew",
    "terminated", "truncated", "info", "policy" and "env_id" in a normal env step.
    Alternatively, it may also accept the keys "obs_next", "rew", "done", "info",
    "policy" and "env_id".
    It returns either a dict or a :class:`~tianshou.data.Batch` with the modified
    keys and values. Examples are in "test/base/test_collector.py".

    .. note::

        Please make sure the given environment has a time limitation if using n_episode
        collect option.

    .. note::

        In past versions of Tianshou, the replay buffer that was passed to `__init__`
        was automatically reset. This is not done in the current implementation.
    """

    def __init__(
        self,
        policy: GoalPolicy,
        dynamics: Dynamics,
        env: Union[gym.Env, BaseVectorEnv],
        normalize_goal: Optional[Callable] = None,
        normalize_obs: Optional[Callable] = None,
        buffer: Optional[VectorGCReplayBufferManager] = None,
        timeout: int = 30,
        preprocess_fn: Optional[Callable[..., Batch]] = None,
        exploration_noise: bool = False,
        name: str = "",
        root_dir: str = "",
        save_gif_num: Union[bool, int] = False,
        num_factors: int = 30,
        render: bool = False,
        separate_her: bool = False,
        num_her_resamples: int = 1,
        her_traj_length: int = -1,
        use_lowest_post: bool = False,
        goal_conditioned: bool = False
    ) -> None:
        self.name = name
        self.dynamics = dynamics  # this is needed for reset
        self.buffer = buffer
        self.timeout = float(timeout)
        self.exp_path = root_dir
        self.save_gif_num = save_gif_num
        self.render = render
        self.epoch_num = 0
        self.her_traj_length = her_traj_length
        self.normalize_goal = normalize_goal
        self.normalize_obs = normalize_obs
        self.use_lowest_post = use_lowest_post
        self.goal_conditioned = goal_conditioned

        super().__init__(policy, env, buffer=buffer, exploration_noise=exploration_noise, preprocess_fn=preprocess_fn)

        self.random_all = False

        # extra logging metrics
        self.global_logging_metrics = Batch(num_reset=[0])
        # per-environment metrics accumulated over the current episode
        self.epi_logging_metrics = Batch(rew=np.zeros(self.env_num),
                                         success=np.zeros(self.env_num, dtype=bool),
                                         reached_graph=np.zeros(self.env_num, dtype=bool),
                                         reached_goal=np.zeros(self.env_num, dtype=bool),
                                         achieve_graph=np.zeros((self.env_num, self.policy.num_factors, self.policy.num_factors + 1), dtype=bool),
                                         updated=np.zeros(self.env_num, dtype=bool))
        # powers of two: np.dot(graph, graph_to_count_idx) turns each (presumably binary)
        # graph row into a single integer id
        self.graph_to_count_idx = np.power(2, np.arange(num_factors + 1)).astype(int)

        # gathers statistics on distance between achieved and desired goals
        # an empty string disables saving; otherwise it is the save directory
        # TODO: hardcoded right now, deal with either empty string or save location
        self.achieved_desired_records = ""
        if len(self.achieved_desired_records) > 0:
            os.makedirs(self.achieved_desired_records, exist_ok=True)
        self.achieved_desired_queue = list()
        self.achieved_desired_reset_counter = 0

        self.separate_her = separate_her
        self.num_her_resamples = num_her_resamples
        self.temp_her_buffers = list()
        self.temp_her_buffer_ptrs = list()

    def set_her_trajectory_check(self, her_trajectory_check: Callable):
        """Register the trajectory filter forwarded to ``add_her_trajectory``.

        Must be called before ``collect`` when ``separate_her`` and ``buffer.use_her``
        are both enabled; the attribute is not initialized in ``__init__``.
        """
        self.her_trajectory_check = her_trajectory_check

    def add_to_temp_buffer(self, data: Batch):
        """Append one step of ``data`` (one row per environment) to the temporary HER buffers.

        On the first call a buffer of length ``int(self.timeout)`` is allocated per
        environment, every slot pre-filled with a deep copy of that environment's row,
        and the write pointers start at 1 (slot 0 already holds this first row).
        Later calls write ``data[i]`` at ``temp_her_buffer_ptrs[i]`` and advance all
        pointers by one. There is no bounds check here. NOTE(review): pointers are
        presumably rewound by ``add_her_trajectory`` (which receives them) when an
        episode ends — confirm.
        """
        if len(self.temp_her_buffers) == 0:
            self.temp_her_buffers = [Batch.cat([copy.deepcopy(data[i:i+1]) for _ in range(int(self.timeout))]) for i in range(self.env_num)]
            self.temp_her_buffer_ptrs = np.ones(self.env_num).astype(int)
        else:
            for i in range(self.env_num):
                self.temp_her_buffers[i][self.temp_her_buffer_ptrs[i]] = data[i]
            self.temp_her_buffer_ptrs += 1

    def _reset_env_with_ids(
        self,
        local_ids: Union[List[int], np.ndarray],
        global_ids: Union[List[int], np.ndarray],
        gym_reset_kwargs_list: Optional[List[Dict[str, Any]]] = None,
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Reset the environments ``global_ids`` and rebuild their ``obs_next`` rows.

        :param local_ids: row indices into ``self.data`` of the environments to reset.
        :param global_ids: environment ids passed to ``env.reset``.
        :param gym_reset_kwargs_list: only used when ``self.goal_conditioned``; passed
            to ``env.reset`` as its second positional argument.
        :param gym_reset_kwargs: keyword arguments for ``env.reset``.

        ``info`` from the reset is written into ``self.data.info``, ``time`` is zeroed,
        and ``obs_next`` is rebuilt from the reset observation with the goals found in
        the new ``info`` and a zeroed ``reached_graph_counter``.
        """
        gym_reset_kwargs = gym_reset_kwargs if gym_reset_kwargs else {}
        # vectorized environment handles individual reset functions
        if self.goal_conditioned:
            obs_reset, info = self.env.reset(global_ids, gym_reset_kwargs_list, **gym_reset_kwargs)
        else:
            obs_reset, info = self.env.reset(global_ids, **gym_reset_kwargs)
        if self.preprocess_fn:
            processed_data = self.preprocess_fn(
                obs=obs_reset, info=info, env_id=global_ids
            )
            obs_reset = processed_data.get("obs", obs_reset)
            info = processed_data.get("info", info)
        self.data.info[local_ids] = info

        # adapted below this line
        self.data.time[local_ids] = 0

        # just a placeholder, achieved and desired target goals should be reassigned later
        self.data.obs_next[local_ids] = self.normalize_batch(construct_her_obs(achieved_goal=self.data.info.achieved_goal[local_ids],
                                                          desired_goal=self.data.info.desired_goal[local_ids],
                                                          env_obs=obs_reset,
                                                          reached_graph_counter=np.zeros(len(local_ids)),))

    def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
        """Reset every environment and wrap the observations in goal-conditioned form.

        After the base-class reset, ``self.data.obs`` is replaced by a normalized
        ``construct_her_obs`` batch, ``self.goals`` caches its ``desired_goal``, and
        ``time``/``target`` are initialized.

        NOTE(review): ``self.goals`` holds *normalized* goals and is later passed back
        to ``env.reset`` as ``options['goal']`` in ``collect``; it is also only refreshed
        here, not on per-environment resets — confirm this is intended.
        """
        super().reset_env(gym_reset_kwargs=gym_reset_kwargs)

        self.data.obs = self.normalize_batch(construct_her_obs(achieved_goal=self.data.info.achieved_goal,
                                                desired_goal=self.data.info.desired_goal,
                                                env_obs=self.data.obs,
                                                reached_graph_counter=np.zeros(len(self.env)),))

        self.goals = self.data.obs.desired_goal

        # target is the state of the goal variable
        self.data.update(
            time=np.zeros(len(self.env)),
            target=self.policy.get_target(self.data),         # used by history policy
        )

    def record_reset_logging_metrics(self, gc_metrics, local_ids, global_ids):
        """Append the number of environments reset at this step to ``gc_metrics.num_reset``.

        :return: the same ``gc_metrics`` object, mutated in place.
        """
        # TODO: any more global logging metrics should go here
        num_reset = len(global_ids)
        gc_metrics.num_reset += [num_reset]
        return gc_metrics

    def record_reset_all_logging_metrics(self, epi_metrics, local_ids, global_ids):
        """Flush the finished episodes' metrics and zero them for the next episode.

        For each key, the accumulated ``self.epi_logging_metrics`` values of
        ``global_ids`` are appended to ``epi_metrics[key]``; the same values and the
        desired goals of ``local_ids`` are passed to ``policy.update_stats``; then the
        per-environment accumulators are reset to zero/False.

        :return: ``epi_metrics``, mutated in place.
        """
        num_reset = len(global_ids)
        for k in epi_metrics.keys():
            epi_metrics[k].extend(self.epi_logging_metrics.get(k)[global_ids])

        # goal conditioned metrics have specific keys
        # update stats will update the achieved graph information
        self.policy.update_stats(self.epi_logging_metrics[global_ids], self.data.obs.desired_goal[local_ids])

        # reset the appropriate indices (dtypes match the allocation in __init__)
        self.epi_logging_metrics[global_ids] = Batch(rew=np.zeros(num_reset),
                                                     success=np.zeros(num_reset, dtype=bool),
                                                     reached_graph=np.zeros(num_reset, dtype=bool),
                                                     reached_goal=np.zeros(num_reset, dtype=bool),
                                                     achieve_graph=np.zeros((num_reset, self.policy.num_factors, self.policy.num_factors + 1), dtype=bool),
                                                     updated=np.zeros(num_reset, dtype=bool))

        return epi_metrics

    def update_logging_metrics(self, data, global_ids):
        """Fold this step's per-environment statistics into ``self.epi_logging_metrics``.

        Boolean metrics (``success``, ``reached_graph``, ``reached_goal``,
        ``achieve_graph``) are OR-ed over the episode; floating-point metrics
        (``rew``, taken from the raw environment reward ``data.env_rew``) are summed.
        ``achieve_graph`` is additionally OR-ed with ``data.true_graph``.

        :param data: current transition batch, after ``collect`` filled in ``reached``,
            ``graph`` and ``true_graph``.
        :param global_ids: environment ids the rows of ``data`` correspond to.
        :raises NotImplementedError: for a metric array that is neither bool nor float.
        """
        epi_metric_stats = Batch(rew=data.env_rew,
                                 success=data.info.get("success", np.zeros_like(data.env_rew, dtype=bool)),
                                 reached_graph=data.reached.reached_graph.astype(bool),
                                 reached_goal=data.reached.reached_goal.astype(bool),
                                 achieve_graph=data.graph)
        # TODO: might be broken
        for metrics, stats in zip([self.epi_logging_metrics],
                                  [epi_metric_stats]):
            for k, v in metrics.items():
                if k not in stats:
                    continue  # e.g. "updated" has no per-step value
                assert isinstance(v, np.ndarray)
                if v.dtype == bool:
                    # success type: use np.logical_or to update
                    v[global_ids] = v[global_ids] | stats[k]
                elif np.issubdtype(v.dtype, np.floating):
                    # reward type: use np.add to update
                    v[global_ids] = v[global_ids] + stats[k]
                else:
                    raise NotImplementedError

            if "achieve_graph" in metrics:
                metrics["achieve_graph"][global_ids] = metrics["achieve_graph"][global_ids].astype(bool) | data.true_graph.astype(bool)

    def update_her_add_statistics(self, old_statistics, new_statistics):
        """Accumulate ``her/*`` counters from ``new_statistics`` into ``old_statistics``.

        ``her/achieved_desired`` entries are also queued; once the queue exceeds
        10000 entries it is saved as ``achieved_desired<N>.npy`` (only if
        ``self.achieved_desired_records`` is non-empty) and cleared.

        :return: ``old_statistics``, mutated in place.
        """
        for k in ["num_ended","reached", "num_traj", "graph_totals", "true_graph_totals", "true_reached", "match_true", "false_positive", "false_negative", "total_positive", "total_negative", "achieved_desired"]:
            if "her/" + k in old_statistics: old_statistics["her/" + k] += new_statistics["her/" + k]
            else: old_statistics["her/" + k] = new_statistics["her/" + k]
        self.achieved_desired_queue += new_statistics["her/achieved_desired"]
        if len(self.achieved_desired_queue) > 10000:
            if len(self.achieved_desired_records) > 0:
                np.save(os.path.join(self.achieved_desired_records, "achieved_desired" + str(self.achieved_desired_reset_counter) + ".npy"), np.array(self.achieved_desired_queue))
            self.achieved_desired_queue = list()
            self.achieved_desired_reset_counter += 1
        return old_statistics

    def demonstrate_override(self, img):
        """Show the rendered frames and read a 2D action from the keyboard.

        ``img[0]`` is shown in window "demo_frame" (100 ms wait); the remaining
        frames go to "demo_frame0", "demo_frame1", .... Keys map to actions:
        ``a`` -> [0, -1], ``d`` -> [0, 1], ``w`` -> [-1, 0], ``s`` -> [1, 0],
        anything else -> [0, 0].

        :return: the chosen action stacked ``self.env_num`` times.
        """
        # TODO: can only demonstrate 2D actions
        for i, iv in enumerate(img[1:]):
            cv2.imshow("demo_frame"+str(i), iv)
            k = cv2.waitKey(1)
        cv2.imshow("demo_frame", img[0])
        k = cv2.waitKey(100)
        if k == ord('a'):
            action = np.array([0,-1])
        elif k ==ord("d"):
            action = np.array([0,1])
        elif k == ord('w'):
            action = np.array([-1,0])
        elif k ==ord("s"):
            action = np.array([1,0])
        else:
            action = np.array([0,0])
        return np.stack([action.copy() for _ in range(self.env_num)])

    def normalize_batch(self, batch_obs):
        """Apply the configured normalizers to ``batch_obs`` in place and return it.

        ``achieved_goal`` and ``desired_goal`` go through ``self.normalize_goal``;
        ``observation`` goes through ``self.normalize_obs``. A normalizer left as
        ``None`` (the constructor default) is treated as the identity.
        """
        if self.normalize_goal is not None:
            batch_obs.achieved_goal = self.normalize_goal(batch_obs.achieved_goal)
            batch_obs.desired_goal = self.normalize_goal(batch_obs.desired_goal)
        if self.normalize_obs is not None:
            batch_obs.observation = self.normalize_obs(batch_obs.observation)
        return batch_obs

    def collect(
        self,
        n_step: Optional[int] = None,
        n_episode: Optional[int] = None,
        random: bool = False,
        n_steps_per_goal: Optional[int] = 1,
        render: Optional[float] = None,
        no_grad: bool = True,
        goal_conditioned: bool = False,
        gym_reset_kwargs: Optional[Dict[str, Any]] = None,
        demonstrate: bool = False,
        show_frame: bool= False,
        save_frame_dir: str= "",
    ) -> Dict[str, Any]:
        """Collect a specified number of step or episode.

        To ensure unbiased sampling result with n_episode option, this function will
        first collect ``n_episode - env_num`` episodes, then for the last ``env_num``
        episodes, they will be collected evenly from each env.

        :param int n_step: how many steps you want to collect.
        :param int n_episode: how many episodes you want to collect.
        :param bool random: whether to use ``policy.random_sample`` for collecting data
            (forced on when ``self.random_all`` is set). Default to False.
        :param int n_steps_per_goal: goal-conditioned mode only. Each environment's
            counter advances once per loop iteration; when a reset happens, envs whose
            counter reached this value request a new goal (``options={'goal': None}``)
            and restart their counter, the others are reset with ``self.goals[i]``.
            Default to 1.
        :param float render: accepted for interface compatibility but unused;
            rendering is controlled by ``self.render``.
        :param bool no_grad: whether to retain gradient in policy.forward(). Default to
            True (no gradient retaining).
        :param bool goal_conditioned: if True, all envs are reset with
            ``options={'goal': None}`` before collecting, and per-env resets pass goal
            options as described for ``n_steps_per_goal``.
        :param gym_reset_kwargs: extra keyword arguments to pass into the environment's
            reset function (ignored in goal-conditioned mode). Defaults to None.
        :param bool demonstrate: if True, actions are read from the keyboard via
            ``demonstrate_override`` instead of the policy.
        :param show_frame: visualizes frames through cv2
        :param save_frame_dir: saves frames in the particular folder, where each thread goes to a different folder

        .. note::

            One and only one collection number specification is permitted, either
            ``n_step`` or ``n_episode``.

        :return: A dict including the following keys

            * ``n/ep`` collected number of episodes.
            * ``n/st`` collected number of steps.
            * ``rews`` array of episode reward over collected episodes.
            * ``lens`` array of episode length over collected episodes.
            * ``idxs`` array of episode start index in buffer over collected episodes.
            * ``rew`` mean of episodic rewards.
            * ``len`` mean of episodic lengths.
            * ``rew_std`` standard error of episodic rewards.
            * ``len_std`` standard error of episodic lengths.
            * ``episode/<metric>`` mean of each per-episode logging metric (NaN if none).
            * ``her/...`` HER summary statistics when ``separate_her`` and ``buffer.use_her``.
        """
        assert not self.env.is_async, "Please use AsyncCollector if using async venv."
        if n_step is not None:
            assert n_episode is None, (
                f"Only one of n_step or n_episode is allowed in Collector."
                f"collect, got n_step={n_step}, n_episode={n_episode}."
            )
            assert n_step > 0
            if not n_step % self.env_num == 0:
                warnings.warn(
                    f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
                    "which may cause extra transitions collected into the buffer."
                )
            ready_env_ids = np.arange(self.env_num)
        elif n_episode is not None:
            assert n_episode > 0
            ready_env_ids = np.arange(min(self.env_num, n_episode))
            self.data = self.data[:min(self.env_num, n_episode)]
        else:
            raise TypeError(
                "Please specify at least one (either n_step or n_episode) "
                "in GCCollector.collect()."
            )

        start_time = time.time()

        step_count = 0
        per_goal_step_count = [0] * self.env_num
        episode_count = 0
        episode_rews = []
        episode_lens = []
        episode_start_indices = []
        her_add_statistics = dict()

        # extra logging metrics
        epi_metrics = ObjDict({k: [] for k in self.epi_logging_metrics.keys()})
        # NOTE(review): global_metrics is accumulated but not included in the result
        global_metrics = ObjDict({k: [] for k in self.global_logging_metrics.keys()})

        random = random | self.random_all

        ###GIF EDIT###
        # gifs are only recorded while evaluating (policy not in training mode)
        save_gif = self.save_gif_num and not self.policy.training
        num_of_gifs = min(self.save_gif_num, len(ready_env_ids))
        pause = 0.05

        if save_gif:
            frames = [[] for _ in range(num_of_gifs)]

        if goal_conditioned:
            # start every environment on a freshly sampled goal
            gym_reset_kwargs = {'options': {'goal': None}}
            self.reset_env(gym_reset_kwargs=gym_reset_kwargs)

        while True:
            assert len(self.data) == len(ready_env_ids)
            # restore the state: if the last state is None, it won't store
            last_state = self.data.policy.pop("hidden_state", None)
            # get the next action
            if demonstrate:
                env_render_cur_frames = self.env.render()
                act_sample = self.demonstrate_override(env_render_cur_frames)
                self.data.update(act=act_sample)
            elif random:
                # policy handles goal sampling and resampling
                act_sample = self.policy.random_sample(num = len(ready_env_ids))
                self.data.update(act=act_sample)
            else:
                if no_grad:
                    with torch.no_grad():  # faster than retain_grad version
                        # self.data.obs will be used by agent to get result (handles goal sampling)
                        result = self.policy(self.data, last_state)
                else:
                    result = self.policy(self.data, last_state)
                # update state / act / policy into self.data
                policy = result.get("policy", Batch())
                assert isinstance(policy, Batch)
                state = result.get("state", None)
                if state is not None:
                    policy.hidden_state = state  # save state into buffer
                act = to_numpy(result.act)
                if self.exploration_noise:
                    act = self.policy.exploration_noise(act, self.data)

                self.data.update(policy=policy, act=act)

            # get bounded and remapped actions first (not saved into buffer)
            action_remap = self.policy.map_action(self.data.act)

            # step in env
            obs_next, rew, terminated, truncated, info = self.env.step(
                action_remap,  # type: ignore
                ready_env_ids
            )
            done = np.logical_or(terminated, truncated)

            self.data.update(
                obs_next=obs_next,
                rew=rew,
                terminated=terminated,
                truncated=truncated,
                done=done,
                info=info
            )
            if self.preprocess_fn:
                self.data.update(
                    self.preprocess_fn(
                        obs_next=self.data.obs_next,
                        rew=self.data.rew,
                        done=self.data.done,
                        info=self.data.info,
                        policy=self.data.policy,
                        env_id=ready_env_ids,
                        act=self.data.act,
                    )
                )

            # ---------------------------------- major modification starts ---------------------------------- #

            # keep the raw environment reward; "rew" is overwritten by the goal-conditioned reward below
            self.data.update(env_rew=rew)

            # time: steps since the episode started (restarts at 1 after an env "done")
            ep_time = self.data.time * (1 - done.astype(int)) + 1  # TODO: might be last_new_action
            self.data.obs.reached_graph_counter = self.data.obs.reached_graph_counter * (1 - done.astype(int))

            # achieved target goal: the actual goal reached at the current time step
            # achieved target goal computation uses (obs, graph, obs_next)
            # achieved target goal must be computed before the terminate and reward chains, but after graph
            # NOTE: Batch does not copy ndarrays, so obs_next.reached_graph_counter is the
            # same array object as obs.reached_graph_counter at this point
            self.data.obs_next = self.normalize_batch(construct_her_obs(achieved_goal=self.data.info.achieved_goal,
                                                   desired_goal=self.data.info.desired_goal,
                                                   env_obs=obs_next,
                                                   reached_graph_counter=self.data.obs.reached_graph_counter,))

            # target, next_target, target_diff: gets the factored state for achieved goals for all factors
            # used by rrt to calculated reward & lower policy termination, and lower policy to extract achieved goal
            target, next_target = self.policy.get_target(self.data), self.policy.get_target(self.data, next=True)
            target_diff = next_target - target

            # true_graph: the true graph connectivity (from the env info)
            # graph: the connectivity of the graph as computed by self.dynamics
            self.data.true_graph = self.data.info.factor_graph
            true_graph_count_idx = np.dot(self.data.info.factor_graph, self.graph_to_count_idx)
            graph = self.dynamics(self.data)
            graph_count_idx = np.dot(graph, self.graph_to_count_idx)

            self.data.update(
                time=ep_time,
                target=target,
                next_target=next_target,
                target_diff=target_diff,
                graph=graph,
                graph_count_idx=graph_count_idx,
                true_graph_count_idx=true_graph_count_idx
            )

            # record reward, terminate is goal reaching, truncate is timeout reached
            # Overwrites reward from the environment with the goal conditioned one
            reward, term, trunc = self.policy.check_rew_term_trunc(self.data)

            self.data.update(
                rew=reward,
                # TODO: in wider rew/terminated situations, this might not work and rew should be overwritten
                terminated=term,
                truncated=trunc,
            )

            # reached_graph indicates if a nontrivial graph was achieved
            # reached_goal indicates if the desired goal was reached
            # these values are mostly used for logging, but "reached" is defined by the reward function
            reached_graph, reached_goal = self.policy.rewtermdone.check_reached(self.data)
            self.data.reached = Batch(reached_graph=reached_graph, reached_goal=reached_goal, updated=np.ones_like(reached_graph, dtype=bool))
            # build a new array rather than "+=": the counter array is shared with obs
            # (see above), and an in-place add would also change the stored obs counter
            self.data.obs_next.reached_graph_counter = self.data.obs_next.reached_graph_counter + reached_graph / self.timeout

            if self.render and not save_gif:
                img = self.env.render()
                for i in range(self.env_num):
                    frame_dir = os.path.join(save_frame_dir, str(i))
                    os.makedirs(frame_dir, exist_ok=True)
                    imio.imsave(os.path.join(frame_dir, "state" + str(step_count // self.env_num) + ".png"), img[i])

                if show_frame:
                    cv2.imshow("frame", img[0])
                    cv2.waitKey(10)

            # GIF EDIT
            # one gif per episode for each recorded environment, written when it ends
            gc_done = done | term | trunc
            if save_gif:
                env_render_cur_frames = self.env.render()
                for i in range(min(num_of_gifs, len(ready_env_ids))):
                    frames[i].append(np.moveaxis(env_render_cur_frames[i], 2, 0))
                    if gc_done[i]:
                        # TODO: we can also stop when low_done
                        # TODO: only for printer; this is assuming that graph doesn't change
                        print("Saving gif...")
                        gif_name = f"ep:{self.epoch_num}__count:{episode_count}_{i}.gif"
                        print(self.exp_path)
                        gif_dir = os.path.join(str(self.exp_path), "gifs")
                        os.makedirs(gif_dir, exist_ok=True)
                        write_gif(np.array(frames[i]),
                                  os.path.join(gif_dir, gif_name),
                                  fps=1 / pause)
                        frames[i] = []

            # add data into the buffer
            ptr, ep_rew, ep_len, ep_idx = self.buffer.add(
                self.data, buffer_ids=ready_env_ids
            )
            if self.separate_her and self.buffer.use_her: self.add_to_temp_buffer(self.data)
            if self.policy.training:
                # history is for the history of graphs, but we don't need the history quite as much
                # since our goals are not graph dependent
                self.policy.update_history(self.buffer)
                # updates the state visitation information
                self.policy.update_state_counts(self.data)

            # logging
            # most of the logging happens in record_reset_all_logging_metrics
            self.update_logging_metrics(self.data, ready_env_ids)
            gc_done = term | trunc
            if np.any(gc_done):
                env_ind_local = np.where(gc_done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                global_metrics = self.record_reset_logging_metrics(global_metrics, env_ind_local, env_ind_global)

            # ---------------------------------- major modification ends ---------------------------------- #

            # collect statistics
            step_count += len(ready_env_ids)
            per_goal_step_count = [per_goal_step_count[i] + 1 for i in range(self.env_num)]
            reset_done = done | gc_done
            if demonstrate: print(n_step, step_count, n_episode, episode_count, her_add_statistics, self.epi_logging_metrics)
            if np.any(reset_done):
                env_ind_local = np.where(reset_done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_lens.append(ep_len[env_ind_local])
                episode_rews.append(ep_rew[env_ind_local])
                episode_start_indices.append(ep_idx[env_ind_local])

                # this should only have an effect if adding a full HER trajectory
                if self.separate_her and self.buffer.use_her:
                    new_statistics = add_her_trajectory(self.buffer, self.temp_her_buffers, self.temp_her_buffer_ptrs, env_ind_local,
                                                        self.her_trajectory_check, check_rew=self.policy.rewtermdone.check_rew, check_term=self.policy.rewtermdone.check_term,
                                                        num_samples= self.num_her_resamples, her_traj_length=self.her_traj_length, use_lowest_post=self.use_lowest_post)
                    her_add_statistics = self.update_her_add_statistics(her_add_statistics, new_statistics)
                # This is the primary entry point for logging relevant information
                epi_metrics = self.record_reset_all_logging_metrics(epi_metrics, env_ind_local, env_ind_global)

                # now we copy obs_next to obs, but since there might be
                # finished episodes, we have to reset finished envs first.
                if goal_conditioned:
                    # one entry per environment: exhausted per-goal budgets request a new goal
                    reset_kwargs_list = []
                    for i in range(self.env_num):
                        if per_goal_step_count[i] >= n_steps_per_goal:
                            reset_kwargs_list.append({'options':{'goal': None}})
                            per_goal_step_count[i] = 0
                        else:
                            reset_kwargs_list.append({'options': {'goal': self.goals[i]}})
                    self._reset_env_with_ids(
                        env_ind_local, env_ind_global, gym_reset_kwargs_list=reset_kwargs_list
                    )
                else:
                    # pass by keyword: the third positional slot is gym_reset_kwargs_list
                    self._reset_env_with_ids(
                        env_ind_local, env_ind_global, gym_reset_kwargs=gym_reset_kwargs
                    )
                for i in env_ind_local:
                    self._reset_state(i)

                # remove surplus env id from ready_env_ids
                # to avoid bias in selecting environments
                if n_episode:
                    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            self.data.obs = self.data.obs_next

            if (n_step and step_count >= n_step) or \
                    (n_episode and episode_count >= n_episode):
                break

        # generate statistics
        self.collect_step += step_count
        self.collect_episode += episode_count
        self.collect_time += max(time.time() - start_time, 1e-9)

        if n_episode:
            self.data = Batch(
                obs={},
                act={},
                rew={},
                terminated={},
                truncated={},
                done={},
                obs_next={},
                info={},
                policy={}
            )
            self.reset_env()

        if episode_count > 0:
            rews, lens, idxs = list(
                map(
                    np.concatenate,
                    [episode_rews, episode_lens, episode_start_indices]
                )
            )
            rew_mean, rew_std = rews.mean(), rews.std()
            len_mean, len_std = lens.mean(), lens.std()
        else:
            rews, lens, idxs = np.array([]), np.array([], int), np.array([], int)
            rew_mean = rew_std = len_mean = len_std = 0
        result = {
                "n/ep": episode_count,
                "n/st": step_count,
                "rews": rews,
                "lens": lens,
                "idxs": idxs,
                "rew": rew_mean,
                "len": len_mean,
                "rew_std": rew_std,
                "len_std": len_std,
            }
        if self.separate_her and self.buffer.use_her:
            her_add_statistics = her_summary_statistics(her_add_statistics)

            result = {**result, **her_add_statistics}

        # extra logging, default value np.nan will not be recorded by the logger
        result.update({f"episode/{k}": np.mean(v) if len(v) else np.nan for k, v in epi_metrics.items()})

        return result